# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

# import contextlib
# import logging

import torch
# from torch import _C
from torch.cuda import _lazy_call
from torch.nn import Identity
# from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

# from megatron.core.parallel_state import (
#     get_expert_model_parallel_rank,
#     get_expert_tensor_parallel_rank,
#     get_tensor_model_parallel_rank,
# )
from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data

from megatron.core.tensor_parallel.utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks


from megatron.core.tensor_parallel.random import (CheckpointFunction, get_cuda_rng_tracker,
                                                   _set_cuda_rng_state)

    
# HACK(huang.huang): recompute-variance for [somefunc+fa] and [somefunc+linear], 
# which can save a forward for fa/linear when backward recompute 
# 2025.4.2: support list of linear as last_function, and args "mid_function" to support complex situations
class IdentityTupleOp(torch.nn.Module):
    """
    This is a placeholder for IdentityTupleOp(*args) -> args,
    """

    def __init__(self,):
        super().__init__()

    def forward(self, *args):
        return args


class CheckpointFunctionVirance(CheckpointFunction):
    """Checkpoint Function

    This function is adapted from torch.utils.checkpoint with two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly tracked/set/reset.
    """

    # pylint: disable=missing-function-docstring
    @staticmethod
    def forward(ctx, run_function, last_function, mid_function, distribute_saved_activations, *args):
        """Forward pass."""
        if not isinstance(last_function, tuple):
            last_function = (last_function, )
        mid_function = tuple(IdentityTupleOp() for _ in last_function) if mid_function is None else mid_function       
        ctx.run_function = run_function
        ctx.last_function = last_function 
        ctx.mid_function = mid_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)
            outputs = outputs if isinstance(outputs, tuple) else (outputs, )
            total_outputs = []
            for i, func in enumerate(last_function):
                outputs_f = mid_function[i](*outputs)
                outputs_f = outputs_f if isinstance(outputs_f, tuple) else (outputs_f, )
                outputs_f = func(*outputs_f)
                total_outputs.append(outputs_f)
            if len(total_outputs)==1:
                #maintain original behavior when only one last_function 
                total_outputs=total_outputs[0] 
            else:
                flat_outputs = []
                for outputs_f in total_outputs:
                    if isinstance(outputs_f, tuple):
                        #Manually remove bias_out which is 'None', and assign 'None' to grad-bias in the corresponding backward direction
                        outputs_f = tuple([x for x in outputs_f if x is not None])         
                    flat_outputs.append(outputs_f)   
                total_outputs = flat_outputs
                #The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc) 
                # as participating in autograd, while the non-reentrant version does
                total_outputs = sum( [x if isinstance(x, tuple) else (x,) for x in total_outputs ], tuple()) 
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
            )

        # Store everything.
        ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
        tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
        ctx.save_for_backward(*tensor_inputs)

        return total_outputs

    # pylint: disable=missing-function-docstring
    @staticmethod
    def backward(ctx, *args):
        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        # inputs = ctx.saved_tensors
        inputs = tuple(
            t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
        )
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)
            outputs = outputs if isinstance(outputs, tuple) else (outputs, )
            total_outputs = []
            for i,func in enumerate(ctx.mid_function):
                outputs_f = func(*outputs)
                if isinstance(outputs_f, torch.Tensor):
                    outputs_f = [outputs_f,]
                total_outputs.append(outputs_f)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)


        total_grad_input = []
        for i,func in enumerate(ctx.last_function):
            if isinstance(func, Identity):
                grad_input_f = args[i]
            else:
                # Assign 'None' to grad_bias to correspond to the operation of removing 'none' during forward
                grad_out_bias = args[i] if isinstance(args[i], tuple) else (args[i], None)
                grad_input_f = func.backward_custom(*total_outputs[i], *grad_out_bias)
            if isinstance(grad_input_f, torch.Tensor):
                grad_input_f = (grad_input_f,)
            total_grad_input.append(grad_input_f)

        total_outputs_with_grad = []
        total_args_with_grad = []
        for j, outputs in enumerate(total_outputs):
            outputs_with_grad = []
            args_with_grad = []
            for i, output in enumerate(outputs):
                if torch.is_tensor(output) and output.requires_grad:
                    outputs_with_grad.append(output)
                    args_with_grad.append(total_grad_input[j][i])    
            total_outputs_with_grad += outputs_with_grad
            total_args_with_grad += args_with_grad
        torch.autograd.backward(total_outputs_with_grad, total_args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None, None, None) + grads
    
def checkpointVirance(run_function, last_function, distribute_saved_activations, *args, mid_function=None):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunctionVirance.apply(run_function, last_function, mid_function, distribute_saved_activations, *args)



class CheckpointFunctionViranceAttention(CheckpointFunction):
    """Checkpoint Function

    This function is adapted from torch.utils.checkpoint with two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly tracked/set/reset.
    """

    # pylint: disable=missing-function-docstring
    @staticmethod
    def forward(ctx, run_function, last_function, distribute_saved_activations, *args):
        """Forward pass."""
        ctx.run_function = run_function
        ctx.last_function = last_function 
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)
            outputs = last_function.forward_before_fa(*outputs[:4], **outputs[4])
            outputs = last_function.forward_fa(*outputs) 
            #outputs: Union[output=Union[Tensor output, Tensor logsumexp, Tensor dropout_mask], 
            # qkv_format, indices_q, batch_size, attn_mask_type, max_seqlen_q, q_shape, v_shape]
            core_attn_out = last_function.forward_after_fa(*outputs)
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
            )

        # Store everything.
        ctx.save_for_backward(*args, *outputs[0])
        (ctx.qkv_format, ctx.indices_q, ctx.batch_size, 
         ctx.attn_mask_type, ctx.max_seqlen_q, ctx.q_shape, ctx.v_shape) = outputs[1:]

        return core_attn_out

# pylint: disable=missing-function-docstring
    @staticmethod
    def backward(ctx, *args):
        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        fa_output = inputs[-3:]
        inputs = inputs[:-3]
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        detached_ori_outputs = detach_variable(fa_output)
        detached_ori_outputs[0].requires_grad = True #only 0 element need grad in output of FA: [Tensor output, Tensor logsumexp, Tensor dropout_mask]
        # ori_outputs is not requires_grad
        with torch.enable_grad():
            outputs_before_fa = ctx.run_function(*detached_inputs) 
            # outputs_before_fa: query, key, value, attention_mask, {"attn_mask_type":attn_mask_type, "attention_bias":attention_bias, "packed_seq_params":packed_seq_params}
            outputs_before_fa = ctx.last_function.forward_before_fa(*outputs_before_fa[:4], **outputs_before_fa[4])
            outputs = ctx.last_function.forward_after_fa(detached_ori_outputs, 
                                                         ctx.qkv_format, ctx.indices_q,  
                                                         ctx.batch_size, ctx.attn_mask_type, 
                                                         ctx.max_seqlen_q, ctx.q_shape, ctx.v_shape)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # filter out non tensor outputs for backward pass
        outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
        torch.autograd.backward(outputs, args)
        
        #costum bwd fa
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
            with torch.no_grad():
                grad_input = torch.ops.aten._scaled_dot_product_attention_flash_musa_backward(
                    # ori_outputs[0][0].grad,
                    detached_ori_outputs[0].grad,
                    *outputs_before_fa[:3], #q, k, v
                    *detached_ori_outputs, #(Tensor output, Tensor logsumexp, Tensor dropout_mask)
                    is_causal="causal" in ctx.attn_mask_type, #causal same as fwd
                ) 
        
        #bwd before fa: for qkv
        torch.autograd.backward(outputs_before_fa[:3], grad_input)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None, None) + grads
    

def checkpointViranceAttention(run_function, last_function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunctionViranceAttention.apply(run_function, last_function, distribute_saved_activations, *args)
# HACK(huang.huang)


from transformer_engine.musa.pytorch.utils import add_attr
from megatron.core import tensor_parallel
add_attr(tensor_parallel, 'checkpointVirance', checkpointVirance)
add_attr(tensor_parallel, 'checkpointViranceAttention', checkpointViranceAttention)